{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mnist分类任务：\n",
    "\n",
    "- 网络基本构建与训练方法，常用函数解析\n",
    "\n",
    "- torch.nn.functional模块\n",
    "\n",
    "- nn.Module模块\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取Mnist数据集\n",
    "- 会自动进行下载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2.2+cu121\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from matplotlib import pyplot\n",
    "import pickle\n",
    "import gzip\n",
    "from pathlib import Path\n",
    "import requests\n",
    "\n",
    "print(torch.__version__)\n",
    "print(torch.cuda.is_available())"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-07T16:35:08.601004900Z",
     "start_time": "2024-04-07T16:35:08.591860900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# %matplotlib inline"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-06T15:16:23.401503100Z",
     "start_time": "2024-04-06T15:16:23.399177400Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 1.数据准备"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1.1 下载数据"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T15:41:14.412891200Z",
     "start_time": "2024-04-07T15:41:11.150193700Z"
    }
   },
   "outputs": [],
   "source": [
    "# 定义及创建 下载路径\n",
    "DATA_PATH = Path(\"data\")\n",
    "PATH = DATA_PATH / \"mnist\"\n",
    "PATH.mkdir(parents=True, exist_ok=True)\n",
    "URL = \"http://deeplearning.net/data/mnist/\"\n",
    "FILENAME = \"mnist.pkl.gz\"\n",
    "# 下载数据\n",
    "if not (PATH / FILENAME).exists():\n",
    "    content = requests.get(URL + FILENAME).content\n",
    "    (PATH / FILENAME).open(\"wb\").write(content)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1.2 读取数据"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T15:46:24.659369200Z",
     "start_time": "2024-04-07T15:46:24.246988500Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(784,)"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 读取数据\n",
    "with gzip.open((PATH / FILENAME).as_posix(), \"rb\") as f:\n",
    "    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
    "# 拿出一个数据看看\n",
    "x_train[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "784是mnist数据集每个样本的像素点个数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-07T15:47:27.482664300Z",
     "start_time": "2024-04-07T15:47:27.368994700Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 784)\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 640x480 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ0AAAGbCAYAAAAfhk2/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZEklEQVR4nO3dX2xT9/nH8Y8TgyFATIJKUNdUDf8y0SlEQllTulFNpWo3zFjGNjUtQmJLGdNKJm2y1ExMaBdNkDaQVi2dGO1AQ2LqVQstQmulNNrF0opQ/gUjJpgZWQhSswU7IeAUfH4XqP41DbB8HfPYjt8v6Vz4xI/9nXfm945xTnye53kCAMBAUbYXAAAoHEQHAGCG6AAAzBAdAIAZogMAMEN0AABmiA4AwAzRAQCY8Wd7AZKUTCZ1+fJlzZkzRz6fL9vLAQA48DxPQ0NDevDBB1VUdO9zmZyIzuXLl1VZWZntZQAAJqG3t1cPPfTQPe+TEx+vzZkzJ9tLAABM0kTeyzManZ6eHtXV1amsrEzhcFgTvawbH6kBQP6byHt5xqKTSCS0du1arVixQt3d3YpEItq3b1+mHh4AMBV4GfLWW295ZWVl3rVr1zzP87wTJ054TzzxxIRmY7GYJ4mNjY2NLY+3WCz2P9/vM/ZFgpMnT6q+vl4lJSWSpJqaGkUikTveN5FIKJFIpG7H4/FMLQMAkMMy9vFaPB5XVVVV6rbP51NxcbEGBwfH3betrU3BYDC18c01ACgMGYuO3+9XIBAYs2/GjBkaGRkZd9+WlhbFYrHU1tvbm6llAAByWMY+XisvL1dPT8+YfUNDQ5o+ffq4+wYCgXGBAgBMfRk706mrq1NXV1fqdjQaVSKRUHl5eaaeAgCQ5zIWnVWrVikej2vv3r2SpNbWVq1evVrFxcWZegoAQJ7zed4Ef4NzAg4dOqTGxkbNnDlTRUVF6uzs1LJly/7nXDweVzAYzNQyAABZEIvFVFpaes/7ZDQ6knTlyhUdO3ZM9fX1mjdv3oRmiA4A5L+sRCcdRAcA8t9EopMTF/wEABQGogMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCG6AAAzBAdAIAZogMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCG6AAAzBAdAIAZogMAMEN0AABm/NleAJBtxcXFzjPBYPA+rCQzXnrppbTmSkpKnGeqq6udZ3760586z/z2t791nmlsbHSekaQbN244z+zYscN55te//rXzzFTAmQ4AwAzRAQCYyVh0mpub5fP5UtvixYsz9dAAgCkiY/+m093drcOHD2vlypWS0vucHAAwtWUkOjdv3tSZM2e0atUqzZ49OxMPCQCYgjLy8drp06eVTCZVW1urmTNn6tlnn9WlS5fuev9EIqF4PD5mAwBMfRmJTiQSUXV1tfbv369Tp07J7/dr8+bNd71/W1ubgsFgaqusrMzEMgAAOS4j0XnhhRfU3d2txx9/XEuWLNFrr72m999//65nMC0tLYrFYqmtt7c3E8sAAOS4+/LLofPnz1cymVR/f79KS0vH/TwQCCgQCNyPpwYA5LCMnOmEw2EdOHAgdburq0tFRUV8bAYAGCMjZzrLly/Xtm3bVFFRoVu3bmnr1q3auHFjWpfVAABMXRmJzoYNG3TmzBmtX79excXF2rBhg1pbWzPx0ACAKcTneZ6X7UXE4/GcvoAi/t/DDz/sPDN9+nTnmc9+ydjF1772NecZSZo7d67zzPr169N6rqnm3//+t/PM0aNHnWcaGhqcZ65du+Y8I0knT550nvnVr37lPNPZ2ek8k+tisdgd/x3/87j2GgDADNEBAJghOgAAM0QHAGCG6AAAzBAdAIAZogMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghgt+Fqja2tq05jo6Opxn+O82PySTSeeZH/7wh84zw8PDzjPp6O/vT2tucHDQeebcuXNpPddUwwU/AQA5hegAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGb82V4AsuPSpUtpzf3nP/9xnuEq07d99NFHzjNXr151nvnGN77hPCNJo6OjzjP79+9P67lQuDjTAQCYIToAADNEBwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMcMHPAvXf//43rblwOOw8EwqFnGeOHz/uPPPqq686z6TrxIkTzjNPP/2088y1a9ecZx599FHnGUn62c9+ltYc4IIzHQCAGaIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDAjM/zPC/bi4jH4woGg9leBu6T0tJS55mhoSHnmd27dzvPSNKPfvQj55kNGzY4z/zlL39xngHySSwW+5//e+dMBwBghugAAMw4R2dgYEBVVVW6ePFial9PT4/q6upUVlamcDisHPjEDgCQg5yiMzAwoFAoNCY4iURCa9eu1YoVK9Td3a1IJKJ9+/ZleJkAgKnAKTrPPfecnn/++TH7jhw5olgspl27dmnRokVqbW3VG2+8kdFFAgCmBqc/V71nzx5VVVWN+bO2J0+eVH19vUpKSiRJNTU1ikQi93ycRCKhRCKRuh2Px12WAQDIU05nOlVVVeP2xePxMft9Pp+Ki4s1ODh418dpa2tTMBhMbZWVlS7LAADkqUl/e83v9ysQCIzZN2PGDI2MjNx1pqWlRbFYLLX19vZOdhkAgDzg9PHanZSXl6unp2fMvqGhIU2fPv2uM4FAYFyoAABT36TPdOrq6tTV1ZW6HY1GlUgkVF5ePtmHBgBMMZOOzqpVqxSPx7V3715JUmtrq1avXq3i4uJJLw4AMLVM+uM1v9+v119/XY2NjQqHwyoqKlJnZ2cGlgYAmGrSis4Xrzjw7W9/WxcuXNCxY8dUX1+vefPmZWRxmBqsvhIfi8VMnkeSXnzxReeZN99803kmmUw6zwC5bNJnOp9ZsGCB1qxZk6mHAwBMQVzwEwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAw4/O+eMnoLIjH4woGg9leBvLcrFmz0pp75513nGeefPJJ55lvfvObzjPvvfee8wyQLbFYTKWlpfe8D2c6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCG6AAAzBAdAIAZLviJgrdo0SLnmY8//th55urVq84zH3zwgfNMd3e384wktbe3O8/kwNsHcggX/AQA5BSiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCG6AAAzBAdAIAZogMAMEN0AABmiA4AwAwX/ATS0NDQ4Dyzd+9e55k5c+Y4z6Trl7/8pfPMn//8Z+eZ/v5+5xnkBy74CQDIKUQHAGCG6AAAzBAdAIAZogMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghugAAMwQHQCAGS74CRj5yle+4jyza9cu55mnnnrKeSZdu3fvdp555ZVXnGf6+vqcZ2CPC34CAHIK0QEAmHGOzsDAgKqqqnTx4sXUvubmZvl8vtS2ePHiTK4RADBF+F3uPDAwoFAoNCY4ktTd3a3Dhw9r5cqVkqTi4uKMLRAAMHU4nek899xzev7558fsu3nzps6cOaNVq1Zp7ty5mjt3rulfOwQA5A+n6OzZs0fNzc1j9p0+fVrJZFK1tbWaOXOmnn32WV26dOmej5NIJBSPx8dsAICpzyk6VVVV4/ZFIhFVV1dr//79OnXqlPx+vzZv3nzPx2lra1MwGExtlZWVbqsGAOSlSX977YUXXlB3d7cef/xxLVmyRK+99pref//9e569tLS0KBaLpbbe3t7JLgMAkAecvkgwEfPnz1cymVR/f/9df0koEAgoEAhk+qkBADlu0mc64XBYBw4cSN3u6upSUVERH5kBAMaZ9JnO8uXLtW3bNlVUVOjWrVvaunWrNm7cqJKSkkysDwAwhUw6Ohs2bNCZM2e0fv16FRcXa8OGDWptbc3E2gAAUwwX/ARy2Ny5c51n1q5dm9Zz7d2713nG5/M5z3R0dDjPPP30084zsMcFPwEAOYXoAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCG6AAAzBAdAIAZogMAMEN0AABmuMo0AElSIpFwnvH73f86ys2bN51nnnnmGeeZzs5O5xlMDleZBgDkFKIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDAjPvV+gCkpaamxnnme9/7nvNMXV2d84yU3sU70xGJRJxn/va3v92HlSAbONMBAJghOgAAM0QHAGCG6AAAzBAdAIAZogMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghugAAMxwwU8UvOrqaueZl156yXnmu9/9rvPMggULnGcs3bp1y3mmv7/feSaZTDrPIDdxpgMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmOGCn8hJ6VzosrGxMa3nSufinY888khaz5XLuru7nWdeeeUV55lDhw45z2Dq4EwHAGCG6AAAzDhH5+DBg1q4cKH8fr9qa2t19uxZSVJPT4/q6upUVlamcDgsz/MyvlgAQH5zis6FCxe0adMm7dixQ319fVq6dKmampqUSCS0du1arVixQt3d3YpEItq3b999WjIAIF85Refs2bPasWOHfvCDH6iiokI/+clPdPz4cR05ckSxWEy7du3SokWL1NraqjfeeON+rRkAkKecvr0WCoXG3D537pyWLFmikydPqr6+XiUlJZKkmpoaRSKRuz5OIpFQIpFI3Y7H4y7LAADkqbS/SDA6OqqdO3dqy5YtisfjqqqqSv3M5/OpuLhYg4ODd5xta2tTMBhMbZWVlekuAwCQR9KOzvbt2zVr1iw1NTXJ7/crEAiM+fmMGTM0MjJyx9mWlhbFYrHU1tvbm+4yAAB5JK1fDu3o6FB7e7s+/PBDTZs2TeXl5erp6Rlzn6GhIU2fPv2O84FAYFykAABTn/OZTjQaVWNjo9rb27Vs2TJJUl1dnbq6usbcJ5FIqLy8PHMrBQDkPafoXL9+XaFQSOvWrVNDQ4OGh4c1PDysr3/964rH49q7d68kqbW1VatXr1ZxcfF9WTQAID85fbz23nvvKRKJKBKJaM+ePan90WhUr7/+uhobGxUOh1VUVKTOzs5MrxUAkOd8XgYvHXDlyhUdO3ZM9fX1mjdv3oTn4vG4gsFgppaB+6iiosJ55rOPYV38/ve/d5758pe/7DyT6z766CPnmd/85jdpPdfBgwedZ5LJZFrPhakpFouptLT0nvfJ6FWmFyxYoDVr1mTyIQEAUwgX/AQAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCG6AAAzGT0gp/IjnT+WN7u3bvTeq7a2lrnmYULF6b1XLns73//u/PMzp07nWf++te/Os9cv37deQawwpkOAMAM0QEAmCE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCGC37eR4899pjzTDgcdp756le/6jzzpS99yXkm142MjKQ19+qrrzrPtLa2Os9cu3bNeQaYajjTAQCYIToAADNEBwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmCE6AAAzRAcAYIboAADMcMHP+6ihocFkxlIkEnGeeffdd51nbt686Tyzc+dO5xlJunr1alpzANxxpgMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmPF5nudlexHxeFzBYDDbywAATEIsFlNpaek978OZDgDADNEBAJhxis7Bgwe1cOFC+f1+1dbW6uzZs5Kk5uZm+Xy+1LZ48eL7slgAQH6bcHQuXLigTZs2aceOHerr69PSpUvV1NQkSeru7tbhw4c1ODiowcFBHT9+/L4tGACQvyb8RYJ3331Xly9f1ubNmyVJH3zwgdasWaN4PK558+apr69Ps2fPTmsRfJEAAPLfRL5IMOE/Vx0KhcbcPnfunJYsWaLTp08rmUyqtrZWfX19evLJJ/XHP/5RDz/88F0fK5FIKJFIpG7H4/GJLgMAkMfS+iLB6Oiodu7cqS1btigSiai6ulr79+/XqVOn5Pf7U2dDd9PW1qZgMJjaKisr01o8ACC/pPV7Oi0tLTpy5IiOHj2qadOmjfnZpUuXVFVVpcHBwbueZt3pTIfwAEB+y+jHa5/p6OhQe3u7Pvzww3HBkaT58+crmUyqv7//rk8eCAQUCARcnxoAkOecPl6LRqNqbGxUe3u7li1bJkkKh8M6cOBA6j5dXV0qKirizAUAMM6Ez3SuX7+uUCikdevWqaGhQcPDw5Kkmpoabdu2TRUVFbp165a2bt2qjRs3qqSk5L4tGgCQp7wJevvttz1J47ZoNOq9/PLLXjAY9MrLy73m5mZveHh4og/reZ7nxWKxOz42GxsbG1v+bLFY7H++33PBTwBARnDBTwBATiE6AAAzRAcAYIboAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGCG6AAAzBAdAIAZogMAMEN0AABmiA4AwAzRAQCYIToAADNEBwBghugAAMwQHQCAGaIDADBDdAAAZogOAMAM0QEAmMmJ6Hiel+0lAAAmaSLv5TkRnaGhoWwvAQAwSRN5L/d5OXCakUwmdfnyZc2ZM0c+ny+1Px6Pq7KyUr29vSotLc3iCrOL1+E2XofbeB1u43W4LRdeB8/zNDQ0pAcffFBFRfc+l/EbremeioqK9NBDD93156WlpQV9UH2G1+E2XofbeB1u43W4LduvQzAYnND9cuLjNQBAYSA6AAAzOR2dQCCg7du3KxAIZHspWcXrcBuvw228DrfxOtyWb69DTnyRAABQGHL6TAcAMLUQHQCAGaIDADBDdHJcc3OzfD5falu8eHG2l4QsGBgYUFVVlS5evJjax7FRmA4ePKiFCxfK7/ertrZWZ8+elZQ/x0PORqenp0d1dXUqKytTOBwu2OuzdXd36/DhwxocHNTg4KCOHz+e7SWZutObbaEdGwMDAwqFQmNeA6nwjo27vdkW0vFw4cIFbdq0STt27FBfX5+WLl2qpqYmSXl0PHg56MaNG94jjzzi/fjHP/bOnz/vfetb3/L+9Kc/ZXtZ5j799FOvtLTUGxoayvZSsuKTTz7xHnvsMU+SF41GPc8rzGPjqaee8n73u9+NeR0K7dg4f/68V1ZW5r355pvelStXvO9///veypUrC+54eOedd7zdu3enbnd0dHgzZ87Mq+MhJ6Pz1ltveWVlZd61a9c8z/O8EydOeE888USWV2Xv448/9mbPnu0tWrTImzFjhvfMM894//rXv7K9LDN3erMtxGPjn//8p+d53pjXodCOjbu92Rbi8fB5f/jDH7yampq8Oh5y8uO1kydPqr6+XiUlJZKkmpoaRSKRLK/KXiQSUXV1tfbv369Tp07J7/dr8+bN2V6WmT179qi5uXnMvkI8NqqqqsbtK7RjIxQKjfnPd+7cOS1ZsqQgj4fPjI6OaufOndqyZUteHQ85+cuhv/jFL3Tjxg21t7en9j3wwAP6xz/+obKysiyuLLsuXbqkqqoqDQ4OFtQFDn0+n6LRqB555JGCPjY+/zp8USEdG6Ojo3r00Uf185//XOfPny/Y46GlpUVHjhzR0aNHNW3atDE/y+XjISfPdPx+/7hLOsyYMUMjIyNZWlFumD9/vpLJpPr7+7O9lKzh2LizQjo2tm/frlmzZqmpqalgj4eOjg61t7frwIED44Ij5fbxkJPRKS8v1yeffDJm39DQkKZPn56lFWVHOBzWgQMHUre7urpUVFSkysrKLK4quzg2bivUY+OLb7aFeDxEo1E1Njaqvb1dy5Ytk5Rfx0NO/D2dL6qrq9OePXtSt6PRqBKJhMrLy7O4KnvLly/Xtm3bVFFRoVu3bmnr1q3auHFj6vPrQsSxcVshHht3erMttOPh+vXrCoVCWrdunRoaGjQ8PCzp9r9l5c3xkO1vMtzJp59+6j3wwAOprz42NTV5oVAoy6vKjpdfftkLBoNeeXm519zc7A0PD2d7Seb0ha8KF+qx8fnXwfMK69gYGRnxli1b5r344ove0NBQahsdHS2o4+Htt9/2JI3botFo3hwPOflFAkk6dOiQGhsbNXPmTBUVFamzszP1/25QWL74D+gcG4Xn4MGD+s53vjNufzQa1alTpzge8kjORkeSrly5omPHjqm+vl7z5s3L9nKQQzg28HkcD/kjp6MDAJhacvLbawCAqYnoAADMEB0AgBmiAwAwQ3QAAGaIDgDADNEBAJghOgAAM0QHAGDm/wCPgC45lWlunwAAAABJRU5ErkJggg=="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 将 784像素点 reshape成 28*28 的像素矩阵显示\n",
    "pyplot.imshow(x_train[0].reshape((28, 28)), cmap=\"gray\")\n",
    "print(x_train.shape)  # 打印大小"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "![avatar](./data/img/4.png)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![avatar](./data/img/5.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.数据转化\n",
    "## 2.1注意数据需转换成tensor才能参与后续建模训练\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-07T15:51:02.481034100Z",
     "start_time": "2024-04-07T15:51:02.414641400Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])\n",
      "torch.Size([50000, 784])\n",
      "tensor(0) tensor(9)\n"
     ]
    }
   ],
   "source": [
    "# ndarray转为tensor格式\n",
    "x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))\n",
    "# n, c = x_train.shape\n",
    "# x_train, x_train.shape, y_train.min(), y_train.max()\n",
    "print(x_train, y_train)  # 打印 训练集特征值 训练集目标值\n",
    "print(x_train.shape)  # 打印训练集特征值的大小\n",
    "print(y_train.min(), y_train.max())  # 打印训练集目标值最小值及最大值 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 torch.nn.functional\n",
    "torch.nn.functional 很多层和函数在这里都会见到\n",
    "torch.nn.functional中有很多功能，后续会常用的。那什么时候使用nn.Module，什么时候使用nn.functional呢？一般情况下，如果模型有可学习的参数，最好用nn.Module，其他情况nn.functional相对更简单一些"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T15:51:53.163527900Z",
     "start_time": "2024-04-07T15:51:53.160777700Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "def model(xb):\n",
    "    return xb.mm(weights) + bias  # wx+b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-07T15:53:03.784914400Z",
     "start_time": "2024-04-07T15:53:03.762966400Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(13.4703, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# 对数据进行64分包\n",
    "bs = 64\n",
    "xb = x_train[0:bs]  # a mini-batch from x\n",
    "yb = y_train[0:bs]\n",
    "weights = torch.randn([784, 10], dtype=torch.float, requires_grad=True)\n",
    "bias = torch.zeros(10, requires_grad=True)\n",
    "\n",
    "print(loss_func(model(xb), yb))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.3 创建一个model来更简化代码\n",
    "- 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数\n",
    "- 无需写反向传播函数，nn.Module能够利用autograd自动实现反向传播\n",
    "- Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T15:56:55.117822900Z",
     "start_time": "2024-04-07T15:56:55.113294300Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "class Mnist_NN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = nn.Linear(784, 128)\n",
    "        self.hidden2 = nn.Linear(128, 256)\n",
    "        self.out = nn.Linear(256, 10)\n",
    "        self.dropout = nn.Dropout(0.5)  # 随机杀死神经元，防止过拟合\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.hidden1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = F.relu(self.hidden2(x))\n",
    "        x = self.dropout(x)\n",
    "        x = self.out(x)\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-07T15:57:00.901969400Z",
     "start_time": "2024-04-07T15:57:00.891900300Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mnist_NN(\n",
      "  (hidden1): Linear(in_features=784, out_features=128, bias=True)\n",
      "  (hidden2): Linear(in_features=128, out_features=256, bias=True)\n",
      "  (out): Linear(in_features=256, out_features=10, bias=True)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = Mnist_NN()\n",
    "print(net)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以打印我们定义好名字里的权重和偏置项"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-07T15:57:07.914439100Z",
     "start_time": "2024-04-07T15:57:07.905535900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden1.weight Parameter containing:\n",
      "tensor([[-0.0031, -0.0119, -0.0314,  ...,  0.0260, -0.0003,  0.0113],\n",
      "        [ 0.0241,  0.0125,  0.0226,  ..., -0.0124, -0.0155, -0.0165],\n",
      "        [ 0.0221, -0.0246, -0.0319,  ...,  0.0288,  0.0240,  0.0152],\n",
      "        ...,\n",
      "        [ 0.0112, -0.0144,  0.0282,  ..., -0.0288, -0.0054,  0.0102],\n",
      "        [ 0.0041,  0.0245, -0.0115,  ..., -0.0127, -0.0065,  0.0250],\n",
      "        [ 0.0246, -0.0325,  0.0013,  ...,  0.0230, -0.0042, -0.0023]],\n",
      "       requires_grad=True) torch.Size([128, 784])\n",
      "hidden1.bias Parameter containing:\n",
      "tensor([ 0.0138,  0.0195,  0.0129,  0.0275, -0.0350, -0.0201,  0.0026,  0.0053,\n",
      "         0.0052, -0.0287,  0.0276,  0.0288,  0.0035, -0.0088,  0.0055,  0.0224,\n",
      "         0.0336,  0.0300,  0.0049,  0.0098,  0.0258, -0.0301,  0.0215,  0.0203,\n",
      "         0.0126,  0.0018,  0.0014,  0.0238, -0.0147,  0.0334,  0.0281, -0.0069,\n",
      "        -0.0081,  0.0195,  0.0112, -0.0264,  0.0073,  0.0026, -0.0143, -0.0217,\n",
      "        -0.0066,  0.0151, -0.0065, -0.0075,  0.0078, -0.0029,  0.0035, -0.0239,\n",
      "        -0.0129,  0.0008, -0.0129,  0.0034, -0.0103,  0.0017,  0.0356,  0.0081,\n",
      "         0.0137,  0.0036,  0.0115,  0.0035,  0.0002, -0.0027, -0.0276,  0.0251,\n",
      "        -0.0066, -0.0112, -0.0260,  0.0317,  0.0297,  0.0059,  0.0202,  0.0064,\n",
      "        -0.0195,  0.0168, -0.0327,  0.0063,  0.0105, -0.0293,  0.0332, -0.0011,\n",
      "        -0.0027, -0.0350,  0.0273, -0.0062, -0.0320,  0.0004,  0.0331,  0.0193,\n",
      "         0.0340,  0.0242,  0.0105, -0.0064,  0.0352,  0.0140, -0.0150,  0.0172,\n",
      "        -0.0025, -0.0145,  0.0348, -0.0052,  0.0201,  0.0194,  0.0307,  0.0149,\n",
      "         0.0003, -0.0238, -0.0311,  0.0326,  0.0192, -0.0276,  0.0056, -0.0267,\n",
      "        -0.0300, -0.0259,  0.0191, -0.0161,  0.0159, -0.0187,  0.0092, -0.0066,\n",
      "         0.0284,  0.0239,  0.0011, -0.0269,  0.0251, -0.0323,  0.0188, -0.0014],\n",
      "       requires_grad=True) torch.Size([128])\n",
      "hidden2.weight Parameter containing:\n",
      "tensor([[-0.0584,  0.0321, -0.0669,  ...,  0.0381,  0.0220,  0.0858],\n",
      "        [-0.0447,  0.0150,  0.0199,  ..., -0.0725,  0.0155, -0.0392],\n",
      "        [-0.0528, -0.0680,  0.0760,  ...,  0.0274,  0.0759, -0.0522],\n",
      "        ...,\n",
      "        [-0.0866, -0.0635,  0.0625,  ..., -0.0797,  0.0758, -0.0011],\n",
      "        [-0.0343, -0.0066,  0.0692,  ..., -0.0129,  0.0789,  0.0516],\n",
      "        [-0.0053, -0.0130, -0.0745,  ...,  0.0422,  0.0095,  0.0406]],\n",
      "       requires_grad=True) torch.Size([256, 128])\n",
      "hidden2.bias Parameter containing:\n",
      "tensor([ 0.0081, -0.0066,  0.0471, -0.0603, -0.0497, -0.0465,  0.0459,  0.0122,\n",
      "         0.0522,  0.0009,  0.0118,  0.0158,  0.0847, -0.0483, -0.0872,  0.0819,\n",
      "         0.0862,  0.0222,  0.0559,  0.0788,  0.0458,  0.0184, -0.0862, -0.0794,\n",
      "         0.0590,  0.0575, -0.0355,  0.0628,  0.0317,  0.0348,  0.0091, -0.0279,\n",
      "        -0.0744,  0.0678, -0.0544,  0.0798, -0.0631, -0.0792,  0.0494,  0.0128,\n",
      "        -0.0092, -0.0419, -0.0603,  0.0408, -0.0813,  0.0848,  0.0858,  0.0652,\n",
      "        -0.0803, -0.0416,  0.0136, -0.0398, -0.0462, -0.0600,  0.0093,  0.0725,\n",
      "         0.0166,  0.0054,  0.0816, -0.0754,  0.0773,  0.0385,  0.0164,  0.0200,\n",
      "        -0.0233, -0.0313,  0.0223,  0.0378, -0.0112, -0.0231, -0.0430,  0.0321,\n",
      "         0.0185,  0.0355, -0.0831,  0.0241,  0.0128,  0.0410,  0.0095,  0.0374,\n",
      "         0.0448, -0.0264, -0.0002, -0.0104,  0.0845,  0.0139, -0.0465,  0.0570,\n",
      "        -0.0726, -0.0362,  0.0354,  0.0607, -0.0256,  0.0182, -0.0565, -0.0197,\n",
      "         0.0778,  0.0496,  0.0775, -0.0626,  0.0646, -0.0138,  0.0305,  0.0353,\n",
      "         0.0630, -0.0850, -0.0056,  0.0463, -0.0620,  0.0425,  0.0173, -0.0450,\n",
      "         0.0242, -0.0849, -0.0270, -0.0287,  0.0097,  0.0211,  0.0117,  0.0160,\n",
      "         0.0414,  0.0012,  0.0697, -0.0835,  0.0215, -0.0740, -0.0224, -0.0120,\n",
      "        -0.0768,  0.0348, -0.0500, -0.0372, -0.0367, -0.0767, -0.0760,  0.0299,\n",
      "         0.0598, -0.0010,  0.0281,  0.0676,  0.0790,  0.0472, -0.0571,  0.0187,\n",
      "        -0.0444,  0.0196,  0.0830, -0.0733, -0.0304, -0.0668,  0.0228, -0.0247,\n",
      "        -0.0466, -0.0393, -0.0459,  0.0220, -0.0851,  0.0311,  0.0635, -0.0612,\n",
      "        -0.0484,  0.0576, -0.0348, -0.0043, -0.0399,  0.0068, -0.0038,  0.0756,\n",
      "        -0.0325,  0.0879, -0.0635, -0.0180,  0.0819,  0.0114, -0.0035, -0.0313,\n",
      "        -0.0340,  0.0600, -0.0366,  0.0643, -0.0289,  0.0192, -0.0413,  0.0302,\n",
      "         0.0196, -0.0413, -0.0587,  0.0858,  0.0784, -0.0862, -0.0354,  0.0109,\n",
      "         0.0272, -0.0881, -0.0192,  0.0734,  0.0833,  0.0220,  0.0613,  0.0462,\n",
      "         0.0827, -0.0382,  0.0335,  0.0247,  0.0547,  0.0491,  0.0581,  0.0036,\n",
      "        -0.0424,  0.0217, -0.0733,  0.0170, -0.0498, -0.0324,  0.0805,  0.0503,\n",
      "         0.0462, -0.0340, -0.0752,  0.0331,  0.0858, -0.0358,  0.0644,  0.0852,\n",
      "         0.0572, -0.0064, -0.0840, -0.0806, -0.0611,  0.0652, -0.0630, -0.0800,\n",
      "        -0.0429, -0.0864, -0.0398,  0.0727, -0.0879,  0.0854, -0.0743, -0.0728,\n",
      "         0.0426, -0.0273,  0.0497,  0.0401,  0.0190,  0.0726, -0.0541, -0.0803,\n",
      "        -0.0016, -0.0258, -0.0230, -0.0538, -0.0709,  0.0723, -0.0319, -0.0799],\n",
      "       requires_grad=True) torch.Size([256])\n",
      "out.weight Parameter containing:\n",
      "tensor([[-0.0156,  0.0585, -0.0507,  ...,  0.0601,  0.0060,  0.0049],\n",
      "        [-0.0430,  0.0134, -0.0131,  ..., -0.0258, -0.0525, -0.0480],\n",
      "        [-0.0272, -0.0394, -0.0440,  ..., -0.0409, -0.0142, -0.0058],\n",
      "        ...,\n",
      "        [-0.0028, -0.0158,  0.0134,  ..., -0.0211, -0.0288, -0.0203],\n",
      "        [ 0.0234,  0.0020, -0.0588,  ...,  0.0106,  0.0355,  0.0162],\n",
      "        [ 0.0073,  0.0507,  0.0581,  ...,  0.0383,  0.0142, -0.0345]],\n",
      "       requires_grad=True) torch.Size([10, 256])\n",
      "out.bias Parameter containing:\n",
      "tensor([-0.0025,  0.0575,  0.0287, -0.0170, -0.0329, -0.0281, -0.0329, -0.0040,\n",
      "         0.0568, -0.0443], requires_grad=True) torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "for name, parameter in net.named_parameters():\n",
    "    print(name, parameter, parameter.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.4 使用TensorDataset和DataLoader来简化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T15:58:05.895454800Z",
     "start_time": "2024-04-07T15:58:05.890724200Z"
    }
   },
   "outputs": [],
   "source": [
    "# 数据准备\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "train_ds = TensorDataset(x_train, y_train)\n",
    "train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "\n",
    "valid_ds = TensorDataset(x_valid, y_valid)\n",
    "valid_dl = DataLoader(valid_ds, batch_size=bs * 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T15:58:08.479749900Z",
     "start_time": "2024-04-07T15:58:08.474542600Z"
    }
   },
   "outputs": [],
   "source": [
    "# 定义方法快速的获取数据\n",
    "def get_data(train_ds, valid_ds, bs):\n",
    "    return (\n",
    "        DataLoader(train_ds, batch_size=bs, shuffle=True),\n",
    "        DataLoader(valid_ds, batch_size=bs * 2),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 一般在训练模型时加上model.train()，这样会正常使用Batch Normalization和 Dropout\n",
    "- 测试的时候一般选择model.eval()，这样就不会使用Batch Normalization和 Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T15:58:29.970369700Z",
     "start_time": "2024-04-07T15:58:29.960855400Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# 定义训练方法\n",
    "def fit(steps, model, loss_func, opt, train_dl, valid_dl):\n",
    "    for step in range(steps):\n",
    "        # 训练\n",
    "        model.train()\n",
    "        for xb, yb in train_dl:\n",
    "            loss_batch(model, loss_func, xb, yb, opt)\n",
    "\n",
    "        # 验证\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            losses, nums = zip(\n",
    "                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]\n",
    "            )\n",
    "        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)  # 平均损失 = 总损失 / 样本个数\n",
    "        print('当前step:' + str(step), '验证集损失：' + str(val_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T16:29:29.020032300Z",
     "start_time": "2024-04-07T16:29:29.016861700Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "# 获取模型\n",
    "def get_model():\n",
    "    model = Mnist_NN()\n",
    "    return model, optim.Adam(model.parameters(), lr=0.001)  # SGD or Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T16:29:32.220643300Z",
     "start_time": "2024-04-07T16:29:32.215798600Z"
    }
   },
   "outputs": [],
   "source": [
    "# 计算损失\n",
    "def loss_batch(model, loss_func, xb, yb, opt=None):\n",
    "    loss = loss_func(model(xb), yb)  # 计算损失  xb-预测值  yb-真实值\n",
    "\n",
    "    if opt is not None:\n",
    "        loss.backward()  # 反向传播，计算梯度\n",
    "        opt.step()  # 进行更新\n",
    "        opt.zero_grad()  # torch会将迭代梯度进行累加，故在这里要将之前的梯度进行清空。\n",
    "\n",
    "    return loss.item(), len(xb)  # 还要返回总数以计算平均"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.5 三行搞定！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-07T16:30:30.758610200Z",
     "start_time": "2024-04-07T16:29:36.225572100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前step:0 验证集损失：0.18732796817719935\n",
      "当前step:1 验证集损失：0.13629273714721202\n",
      "当前step:2 验证集损失：0.11832513649128377\n",
      "当前step:3 验证集损失：0.10522761129364372\n",
      "当前step:4 验证集损失：0.10235972150862217\n",
      "当前step:5 验证集损失：0.09990145575255155\n",
      "当前step:6 验证集损失：0.09385116954091936\n",
      "当前step:7 验证集损失：0.09267852237001062\n",
      "当前step:8 验证集损失：0.08713720603697002\n",
      "当前step:9 验证集损失：0.08598045459156856\n",
      "当前step:10 验证集损失：0.0889315923477523\n",
      "当前step:11 验证集损失：0.0849281577335205\n",
      "当前step:12 验证集损失：0.08463994964901358\n",
      "当前step:13 验证集损失：0.0827090285810642\n",
      "当前step:14 验证集损失：0.07994306690935045\n",
      "当前step:15 验证集损失：0.08267633339352906\n",
      "当前step:16 验证集损失：0.08133600186556578\n",
      "当前step:17 验证集损失：0.08016132586877793\n",
      "当前step:18 验证集损失：0.08076238982575014\n",
      "当前step:19 验证集损失：0.07790505788521841\n"
     ]
    }
   ],
   "source": [
    "train_dl, valid_dl = get_data(train_ds, valid_ds, bs)  # 得到数据\n",
    "model, opt = get_model()  # 得到模型和优化器\n",
    "fit(20, model, loss_func, opt, train_dl, valid_dl)  # 进行训练\n",
    "print(\"训练完毕！\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 3.计算准确率"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "128\n",
      "118\n"
     ]
    }
   ],
   "source": [
    "# 通过以下方法可得到输出结果中每个元素属于 1~10 的概率。\n",
    "correct = 0\n",
    "total = 0\n",
    "for xb, yb in valid_dl:\n",
    "    outputs = model(xb)\n",
    "    # print(torch.max(outputs.data, 1))  # 最大值和对应的索引，1表示沿着概率值是1的这个维度越大才是最大值，而不是0的维度。\n",
    "    # print(outputs)  # 输出结果\n",
    "    # print(outputs.shape)  # 输出结果大小\n",
    "    _, predicted = torch.max(outputs.data, 1)  # 最大的值和索引\n",
    "    total += yb.size(0)\n",
    "    print(yb.size(0))  # 验证集目标值包含的元素个数\n",
    "    # 预测值与验证值相等的元素个数，也就是预测正确的个数。item()的作用是将tensor类型转为数值类型\n",
    "    print((predicted == yb).sum().item())  \n",
    "    break"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-07T16:25:29.602409Z",
     "start_time": "2024-04-07T16:25:29.592843700Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-07T16:30:48.963290200Z",
     "start_time": "2024-04-07T16:30:48.870882800Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 10000 test images: 97.85%\n"
     ]
    }
   ],
   "source": [
    "# 最终计算准确率\n",
    "correct = 0\n",
    "total = 0\n",
    "for xb, yb in valid_dl:\n",
    "    outputs = model(xb)\n",
    "    _, predicted = torch.max(outputs.data, 1)  # 最大的值和索引\n",
    "    total += yb.size(0)\n",
    "    correct += (predicted == yb).sum().item()\n",
    "print(f\"Accuracy of the network on the 10000 test images: {100 * correct / total}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 4.练习\n",
    "1. 将优化器由SGD（固定学习率）改成Adam（自适应学习率）\n",
    "2. 更改网络层数和神经元个数观察效果\n",
    "3. 计算当前模型的准确率等于多少"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 5.效率对比\n",
    "- SGD-20次: 86.57%\n",
    "- SGD-100次: 93.32% \n",
    "- Adam-20次: 97.85%\n",
    "- Adam-100次: "
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
